#!/usr/bin/env python3
#
# Test file I/O by loading from restart files and writing to dump files
#
# require: netcdf

from boutdata import restart
from boutdata.collect import collect
from boututils.boutarray import BoutArray
from boututils.datafile import DataFile
from boututils.run_wrapper import shell, shell_safe, launch_safe
import numpy
import os
from sys import exit

# Grid size: number of points in x, y and z, and the guard-cell widths
# in x and y
nx, ny, nz = 8, 16, 4
mxg, myg = 2, 2

print("Making restart I/O test")
# run make, sending its output to make.log
shell_safe("make > make.log")

# Coordinate axes on [0, 1], each shaped so that it broadcasts along its own
# dimension of an (x, y, z) array; the x and y axes include guard cells
x = numpy.linspace(0., 1., nx + 2*mxg).reshape(-1, 1, 1)
y = numpy.linspace(0., 1., ny + 2*myg).reshape(1, -1, 1)
z = numpy.linspace(0., 1., nz).reshape(1, 1, -1)

# Reference data, keyed by variable name: one 3D field, one 2D field (x, y)
# and two FieldPerps (x, z) tagged with different global y-indices
testvars = {
    'f3d': BoutArray(
        numpy.exp(numpy.sin(x + y + z)),
        attributes={'bout_type': 'Field3D'},
    ),
    'f2d': BoutArray(
        numpy.exp(numpy.sin(x + y + 1.))[:, :, 0],
        attributes={'bout_type': 'Field2D'},
    ),
    'fperp_lower': BoutArray(
        numpy.exp(numpy.sin(x + z + 2.))[:, 0, :],
        attributes={'bout_type': 'FieldPerp', 'yindex_global': 0},
    ),
    'fperp_upper': BoutArray(
        numpy.exp(numpy.sin(x + z + 3.))[:, 0, :],
        attributes={'bout_type': 'FieldPerp', 'yindex_global': 16},
    ),
}

# Directory for the single-processor restart file that is later redistributed;
# it is fine if the directory is already there from a previous run
restartdir = os.path.join('data', 'restart')
try:
    os.mkdir(restartdir)
except FileExistsError:
    pass

# Everything that goes into the base restart file, listed in write order
restart_contents = [
    ('MXSUB', nx),
    ('MYSUB', ny),
    ('MZSUB', nz),
    ('MXG', mxg),
    ('MYG', myg),
    ('MZG', 0),
    ('nx', nx+2*mxg),
    ('ny', ny),
    ('nz', nz),
    ('MZ', nz),
    ('NXPE', 1),
    ('NYPE', 1),
    ('NZPE', 1),
    ('tt', 0.),
    ('hist_hi', 0),
    # set BOUT_VERSION to stop collect from changing nz or printing a warning
    ('BOUT_VERSION', 4.),
]
# the test fields themselves follow the scalar metadata
restart_contents += [
    (fieldname, testvars[fieldname])
    for fieldname in ('f3d', 'f2d', 'fperp_lower', 'fperp_upper')
]

base_restart_path = os.path.join(restartdir, 'BOUT.restart.0.nc')
with DataFile(base_restart_path, create=True) as base_restart:
    for varname, value in restart_contents:
        base_restart.write(varname, value)

# Overall result; any failed check below clears this flag
success = True

# Note: expect this to fail for 16 processors, because when there are 2
# y-points per processor, the fperp_lower FieldPerp is in the grid cells of one
# set of processors and also in the guard cells of another set. This means
# valid FieldPerps get written to output files with different
# y-processor-indices, and collect() cannot handle this.
for nproc in [1, 2, 4]:
    # delete any existing output
    shell("rm data/BOUT.dmp.*.nc data/BOUT.restart.*.nc")

    # create restart files for the run
    restart.redistribute(nproc, path=restartdir, output='data')

    print("   %d processor...." % (nproc))

    # run the test executable, keeping its output in a per-nproc log file
    s, out = launch_safe('./test-restart-io', nproc=nproc, pipe=True)
    with open("run.log."+str(nproc), "w") as f:
        f.write(out)

    # check the results
    for name, testvar in testvars.items():
        # only the FieldPerp variables carry a yindex_global attribute to check
        is_fperp = name in ('fperp_lower', 'fperp_upper')

        # check non-evolving version
        result = collect(name+"_once", path='data', xguards=True, yguards=True, info=False)

        if not numpy.all(testvar == result):
            success = False
            print(name+' is different')
            # imported here so that it is only loaded when a mismatch is shown
            from boututils.showdata import showdata
            showdata([result, testvar])
        if is_fperp:
            yindex_result = result.attributes['yindex_global']
            yindex_test = testvar.attributes['yindex_global']
            if yindex_result != yindex_test:
                success = False
                print('Fail: yindex_global of '+name+' is '+str(yindex_result)+' should be '+str(yindex_test))

        # check evolving versions: every collected timeslice must match
        result = collect(name, path='data', xguards=True, yguards=True, info=False)

        for result_timeslice in result:
            if not numpy.all(testvar == result_timeslice):
                success = False
                print(name+' evolving version is different')

        # yindex_global is read from the collected array as a whole, not from
        # a timeslice, so check it once instead of repeating the same
        # comparison (and failure message) for every timeslice
        if is_fperp:
            yindex_result = result.attributes['yindex_global']
            yindex_test = testvar.attributes['yindex_global']
            if yindex_result != yindex_test:
                success = False
                print('Fail: yindex_global of '+name+' evolving version is '+str(yindex_result)+' should be '+str(yindex_test))

if success:
    print('pass')

    # clean up binary files
    shell("rm data/BOUT.dmp.*.nc data/BOUT.restart.*.nc data/restart/BOUT.restart.0.nc")

    exit(0)

# Report failure through the exit status, so that anything checking the return
# code does not mistake a failed run for a pass. The output files are left in
# place (no clean-up) so that the failure can be inspected.
print('fail')
exit(1)
